'''
 * @ author     ：廖传港
 * @ date       ：Created in 2020/10/22 10:10
 * @ description：
 * @ modified By：
 * @ ersion     : 
 * @File        : get_files.py 
'''

import tensorflow as tf
import os
import numpy as np


def get_files(file_dir):
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir):
        name = file.split(sep='.')
        if 'cat' in name[0]:
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            if 'dog' in name[0]:
                dogs.append(file_dir + file)
                label_dogs.append(1)
        image_list = np.hstack((cats, dogs))
        label_list = np.hstack((label_cats, label_dogs))
    # print('There are %d cats\nThere are %d dogs' %(len(cats), len(dogs)))
    # 多个种类分别的时候需要把多个种类放在一起，打乱顺序,这里不需要

    # 把标签和图片都放倒一个 temp 中 然后打乱顺序，然后取出来
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    # 打乱顺序
    np.random.shuffle(temp)

    # 取出第一个元素作为 image 第二个元素作为 label
    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]
    return image_list, label_list


# 测试 get_files
# imgs , label = get_files('/Users/LCG/Desktop/train/test')
# for i in imgs:
# 	print("img:",i)
#
# for i in label:
# 	print('label:',i)